from args import *
import os
import time
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from peft import (
    get_peft_model,
    LoCAConfig,
    FourierConfig,
    LoraConfig,
    AdaLoraConfig,
    PeftType
)
from datasets import load_dataset, load_metric
from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, AutoConfig
from tqdm import tqdm

# torch.autograd.set_detect_anomaly(True)

# Parse the run configuration and echo it so every run's settings show up at
# the top of its console output.
args = get_args()
print(args)

torch.manual_seed(args.seed)
task = args.task
peft_type = args.adapter
device = "cuda"

# Build the adapter configuration selected by args.adapter. Every variant is
# set up for sequence classification with inference_mode disabled.
if peft_type == 'LoRA':
    peft_config = LoraConfig(
        task_type="SEQ_CLS",
        inference_mode=False,
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
    )
elif peft_type == 'LoCA':
    peft_config = LoCAConfig(
        task_type="SEQ_CLS",
        inference_mode=False,
        n_frequency=args.n_frequency,
        scale=args.scale,
        loca_dropout=args.loca_dropout,
        learn_location_iter=args.learn_location_iter,
        dct_mode=args.loca_dct_mode,
    )
elif peft_type == 'FourierFT':
    peft_config = FourierConfig(
        task_type="SEQ_CLS",
        inference_mode=False,
        n_frequency=args.n_frequency,
        scale=args.scale,
    )
elif peft_type == 'DoRA':
    # Same settings as the LoRA branch, with use_dora switched on.
    peft_config = LoraConfig(
        task_type="SEQ_CLS",
        inference_mode=False,
        r=args.lora_r,
        use_dora=True,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
    )
elif peft_type == 'AdaLoRA':
    # The initial rank is 1.5x the target rank given by args.lora_r.
    peft_config = AdaLoraConfig(
        task_type="SEQ_CLS",
        inference_mode=False,
        target_r=args.lora_r,
        init_r=int(1.5 * args.lora_r),
        lora_alpha=args.lora_alpha,
        deltaT=10,
        lora_dropout=args.lora_dropout,
    )
else:
    # Fail right away instead of hitting a NameError on `peft_config` much
    # later, after the datasets have already been loaded and tokenized.
    raise ValueError(
        "Unsupported adapter {!r}; expected one of "
        "'LoRA', 'LoCA', 'FourierFT', 'DoRA', 'AdaLoRA'.".format(peft_type)
    )

def log(*pargs):
    """Append one line, built from *pargs* joined by spaces, to the run's log file.

    The file is ``./logs_glue/<task>/<model tag>/<settings>.txt`` where the
    settings part encodes the batch size, max length, learning rates,
    frequency count, scale and seed read from the module-level ``args``.
    The path is printed on every call.

    The target directory is created on demand, so the first call of a fresh
    run no longer fails with ``FileNotFoundError``.
    """
    # The model tag is the second dash-separated piece of the model name
    # (e.g. "base" for "roberta-base"). Names without a dash fall back to the
    # whole name instead of raising IndexError.
    name_parts = args.model_name_or_path.split("-")
    model_tag = name_parts[1] if len(name_parts) > 1 else name_parts[0]

    path_log = './logs_glue/' + task + '/' + model_tag + '/bs' + str(args.bs) + 'maxlen' + str(args.max_length) + 'f_lr' + str(args.frequency_lr) + 'h_lr' + str(args.head_lr) + \
          'num' + str(args.n_frequency) + 'scale' + str(args.scale) + 'seed' + str(args.seed) + '.txt'
    print(path_log)

    # open(..., 'a+') creates the file but not its parent directories.
    os.makedirs(os.path.dirname(path_log), exist_ok=True)
    with open(path_log, mode='a+') as w:
        w.write(" ".join("{}".format(t) for t in pargs))
        w.write("\n")

# Checkpoints whose name contains one of these substrings are padded on the
# left; every other checkpoint is padded on the right.
_LEFT_PADDED_NAME_PARTS = ("gpt", "opt", "bloom")
_pads_left = any(part in args.model_name_or_path for part in _LEFT_PADDED_NAME_PARTS)
padding_side = "left" if _pads_left else "right"

tokenizer = AutoTokenizer.from_pretrained(
    args.model_name_or_path,
    padding_side=padding_side,
)
# Tokenizers without a pad token reuse the end-of-sequence token for padding.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

# GLUE data and the matching metric for the selected task.
datasets = load_dataset("glue", task)
metric = load_metric("glue", task)

# stsb uses a single output; every other task gets one output per class name
# listed in the training split's "label" feature.
if task != "stsb":
    label_list = datasets["train"].features["label"].names
    num_labels = len(label_list)
else:
    num_labels = 1

def tokenize_function(examples):
    """Tokenize a batch of GLUE examples for the module-level ``task``.

    The text column(s) handed to the tokenizer depend on the task: the single
    ``sentence`` column for sst2/cola, a task-specific pair of columns for
    qnli/qqp/mnli, and ``sentence1``/``sentence2`` for any other task.
    Inputs are truncated to ``args.max_length``.

    Returns the tokenizer's output for the batch.
    """
    text_columns_by_task = {
        'sst2': ("sentence",),
        'cola': ("sentence",),
        'qnli': ("question", "sentence"),
        'qqp': ("question1", "question2"),
        'mnli': ("premise", "hypothesis"),
    }
    # Tasks not listed above use the generic two-sentence layout.
    text_columns = text_columns_by_task.get(task, ("sentence1", "sentence2"))
    texts = [examples[column] for column in text_columns]
    return tokenizer(*texts, truncation=True, max_length=args.max_length)

# Raw columns dropped once an example has been tokenized, keyed by task.
# Tasks missing from the table use the generic sentence1/sentence2 layout.
_raw_columns_by_task = {
    'sst2': ["idx", "sentence"],
    'cola': ["idx", "sentence"],
    'qnli': ["idx", "question", "sentence"],
    'qqp': ["idx", "question1", "question2"],
    'mnli': ["idx", "premise", "hypothesis"],
}
_columns_to_drop = _raw_columns_by_task.get(task, ["idx", "sentence1", "sentence2"])

tokenized_datasets = datasets.map(
    tokenize_function,
    batched=True,
    remove_columns=_columns_to_drop,
)

# The rest of the script reads the targets from the batch as "labels".
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")


def collate_fn(examples):
    """Pad *examples* to the longest one in the batch and return PyTorch tensors."""
    padded_batch = tokenizer.pad(
        examples,
        padding="longest",
        return_tensors="pt",
    )
    return padded_batch


# Build the training loader and the list of validation loaders.
def _build_loader(split_name, shuffle):
    """Return a DataLoader over one tokenized split, batched with ``args.bs``."""
    return DataLoader(
        tokenized_datasets[split_name],
        shuffle=shuffle,
        collate_fn=collate_fn,
        batch_size=args.bs,
    )


train_dataloader = _build_loader("train", shuffle=True)

if task != 'mnli':
    eval_dataloader = _build_loader("validation", shuffle=False)
    eval_dataloaders = [eval_dataloader]
else:
    # mnli has two validation splits; both are evaluated, matched first.
    eval_dataloader1 = _build_loader("validation_matched", shuffle=False)
    eval_dataloader2 = _build_loader("validation_mismatched", shuffle=False)
    eval_dataloaders = [eval_dataloader1, eval_dataloader2]

# Settings applied on top of the checkpoint's stored configuration.
_config_overrides = dict(
    num_labels=num_labels,
    hidden_dropout_prob=args.hidden_dropout_prob,
    finetuning_task=args.task,
    return_dict=True,
)
config = AutoConfig.from_pretrained(args.model_name_or_path, **_config_overrides)

# Load the sequence-classification model first, then wrap it with the adapter.
_backbone = AutoModelForSequenceClassification.from_pretrained(
    args.model_name_or_path,
    config=config,
)

print(peft_config)
model = get_peft_model(_backbone, peft_config)
model.print_trainable_parameters()

# Split the parameters into three groups by object identity:
#   * the classification head,
#   * parameters whose name contains 'spectrum_indices',
#   * every remaining trainable parameter.
head_param = [id(p) for p in model.classifier.parameters()]
indices_id = [id(p) for n, p in model.named_parameters() if 'spectrum_indices' in n]
indices_param = (p for p in model.parameters() if id(p) in indices_id)
others_param = (
    p for p in model.parameters()
    if p.requires_grad and id(p) not in head_param and id(p) not in indices_id
)

# Each group has its own learning rate; the 'spectrum_indices' group is
# excluded from weight decay.
_head_group = {
    "params": model.classifier.parameters(),
    "lr": args.head_lr,
    "weight_decay": args.weight_decay,
}
_indices_group = {
    "params": indices_param,
    "lr": args.location_lr,
    "weight_decay": 0,
}
_others_group = {
    "params": others_param,
    "lr": args.frequency_lr,
    "weight_decay": args.weight_decay,
}
optimizer = AdamW([_head_group, _indices_group, _others_group])

# Linear schedule with warmup. args.warm_step is multiplied by the total
# number of training steps to obtain the warmup length.
_total_training_steps = len(train_dataloader) * args.num_epochs
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=args.warm_step * _total_training_steps,
    num_training_steps=_total_training_steps,
)

# Headline validation scores collected so far; eval_model appends to this.
acc_list = list()
model.to(device)

# With no explicit interval, evaluate once per pass over the training loader.
args.eval_steps = len(train_dataloader) if args.eval_steps is None else args.eval_steps

def eval_model(model, eval_dataloaders):
    """Evaluate *model* on every dataloader and report the task's headline score.

    For each dataloader, predictions are accumulated into the module-level
    ``metric`` and the metric is computed once. The headline score (Pearson
    for stsb, Matthews correlation for cola, accuracy otherwise) is appended
    to the module-level ``acc_list`` and reported, together with the best
    score so far, on stdout and through ``log``.

    Reads the module-level ``epoch`` and ``loss`` assigned by the training
    loop, so it can only be called once those exist. The model is left in
    eval mode. Returns None.
    """
    model.eval()

    # Metric key to track and the plain-text label used in the report line.
    if task == "stsb":
        score_key, best_label = 'pearson', ', current_best_pearson:'
    elif task == 'cola':
        score_key, best_label = 'matthews_correlation', ', current_best_corr:'
    else:
        score_key, best_label = 'accuracy', ', current_best_acc:'
    # The console version of the label is wrapped in ANSI green / reset codes.
    coloured_label = '\033[32m' + best_label + '\033[0m'

    for loader in eval_dataloaders:
        for batch in tqdm(loader):
            batch.to(device)
            with torch.no_grad():
                logits = model(**batch).logits
            # stsb passes the raw outputs on; other tasks pick the arg-max class.
            preds = logits if task == "stsb" else logits.argmax(dim=-1)
            metric.add_batch(predictions=preds, references=batch["labels"])

        eval_metric = metric.compute()
        acc_list.append(eval_metric[score_key])

        file_fields = (f"epoch {epoch}:", eval_metric, best_label, max(acc_list), 'train_loss:', loss)
        console_fields = (f"epoch {epoch}:", eval_metric, coloured_label, max(acc_list), 'train_loss:', loss)
        if task == "stsb":
            # stsb writes the log file before printing; the other tasks do
            # it the other way round.
            log(*file_fields)
            print(*console_fields)
        else:
            print(*console_fields)
            log(*file_fields)



# Training loop.
#
# global_step counts optimizer steps over the whole run. It is initialised
# once, before the epoch loop, so that evaluation runs every args.eval_steps
# steps even when that interval does not divide the number of batches per
# epoch. With the default interval (one epoch's worth of batches) evaluation
# still happens at the end of every epoch.
global_step = 0
for epoch in range(args.num_epochs):
    model.train()
    for step, batch in enumerate(tqdm(train_dataloader)):
        global_step += 1
        batch.to(device)
        outputs = model(**batch)
        # `loss` stays a module-level name on purpose: eval_model reads it
        # (and `epoch`) when it builds its report line.
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        if global_step % args.eval_steps == 0:
            eval_model(model, eval_dataloaders)
            # eval_model leaves the model in eval mode; switch back.
            model.train()